package org.ovirt.engine.core.uutils.ssh;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.SequenceInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.lang.exception.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * SSH dialog to be used with SSHClient class.
 *
 * Provides easy processing of the stdin/stdout of an SSHClient session. Given the limitations of the SSH
 * implementation, this class exists to ease usage of the session.
 *
 * The implementation is a wrapper around SSHClient's executeCommand().
 */
public class SSHDialog implements Closeable {

    private static final int BUFFER_SIZE = 10 * 1024;
    private static final int DEFAULT_SSH_PORT = 22;

    /**
     * Control interface. Callback for the sink.
     */
    public interface Control {
        /**
         * Disconnect session.
         */
        void close() throws IOException;
    }

    /**
     * Dialog sink.
     */
    public interface Sink {
        /**
         * Set control interface.
         *
         * @param control
         *            control.
         */
        void setControl(SSHDialog.Control control);

        /**
         * Set streams to process.
         *
         * @param incoming
         *            incoming stream.
         * @param outgoing
         *            outgoing stream.
         *
         *            Streams are null when sink is removed from session.
         */
        void setStreams(InputStream incoming, OutputStream outgoing);

        /**
         * Start processing. Usually a thread will be created to process streams. This guarantee to be called after
         * setStreams().
         */
        void start();

        /**
         * Stop processing. Called before streams are set to null.
         */
        void stop();
    }

    private static final Logger log = LoggerFactory.getLogger(SSHDialog.class);

    private String host;
    private int port;
    private String user = "root";
    private KeyPair keyPair;
    private String password;
    private long softTimeout = 0;
    private long hardTimeout = 0;

    protected SSHClient client;

    /**
     * Get SSH Client. Used for mocking.
     */
    protected SSHClient getSSHClient() {
        return new SSHClient();
    }

    /**
     * Destructor.
     */
    @Override
    protected void finalize() {
        try {
            close();
        } catch (IOException e) {
            log.error("Finalize exception", ExceptionUtils.getRootCauseMessage(e));
            log.debug("Exception", e);
        }
    }

    /**
     * Get session public key.
     *
     * @return public key or null.
     */
    public PublicKey getPublicKey() {
        if (keyPair == null) {
            return null;
        } else {
            return keyPair.getPublic();
        }
    }

    /**
     * Get host public key.
     */
    public PublicKey getHostKey() throws IOException {
        if (client == null) {
            throw new IOException("Cannot acquire host key, session is disconnected");
        }

        PublicKey hostKey = client.getHostKey();
        if (hostKey == null) {
            throw new IOException("Unable to retrieve host key");
        }

        return hostKey;
    }

    /**
     * Set host to connect to.
     *
     * @param host
     *            host.
     * @param port
     *            port.
     */
    public void setHost(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Set host to connect to.
     *
     * @param host
     *            host.
     */
    public void setHost(String host) {
        setHost(host, DEFAULT_SSH_PORT);
    }

    /**
     * Set user to use.
     *
     * @param user
     *            user.
     */
    public void setUser(String user) {
        this.user = user;
    }

    /**
     * Set password to use. If both password and key pair are set key pair is used.
     */
    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Set key pair. If both password and key pair are set key pair is used.
     *
     * @param keyPair
     *            key pair.
     */
    public void setKeyPair(KeyPair keyPair) {
        this.keyPair = keyPair;
    }

    /**
     * Set soft timeout. Soft timeout is reset when there is session activity.
     *
     * @param timeout
     *            timeout in milliseconds.
     */
    public void setSoftTimeout(long timeout) {
        softTimeout = timeout;
    }

    /**
     * Set hard timeout. Hard timeout is maximum duration of session.
     *
     * @param timeout
     *            timeout in milliseconds.
     */
    public void setHardTimeout(long timeout) {
        hardTimeout = timeout;
    }

    /**
     * Disconnect session.
     */
    public void close() throws IOException {
        if (client != null) {
            client.close();
            client = null;
        }
    }

    /**
     * Connect to host. After connection host fingerprint can be acquired.
     */
    public void connect() throws Exception {
        log.debug(
                "connect enter ({}:{}, {}, {})",
                host,
                port,
                hardTimeout,
                softTimeout);

        try {
            if (client != null) {
                throw new IOException("Already connected");
            }

            client = getSSHClient();
            if (hardTimeout != 0) {
                client.setHardTimeout(hardTimeout);
            }
            if (softTimeout != 0) {
                client.setSoftTimeout(softTimeout);
            }
            client.setHost(host, port);

            log.debug("connecting");
            client.setUser(user);
            client.connect();
        } catch (Exception e) {
            if (client != null) {
                log.debug(
                        "Could not connect to host '{}'",
                        client.getDisplayHost());
            } else {
                log.debug(
                        "Could not connect to host");
            }
            log.debug("Exception", e);
            throw e;
        }
    }

    /**
     * Authenticate.
     */
    public void authenticate() throws Exception {
        client.setPassword(password);
        client.setKeyPair(keyPair);
        client.authenticate();
    }

    /**
     * Execute command.
     *
     * @param sink
     *            sink to use.
     * @param command
     *            command to execute.
     * @param initial
     *            initial input streams to send to host before dialog begins.
     */
    public void executeCommand(
            Sink sink,
            String command,
            InputStream[] initial) throws Exception {

        log.info("SSH execute '{}' '{}'", client.getDisplayHost(), command);

        try (
                final PipedInputStream pinStdin = new PipedInputStream(BUFFER_SIZE);
                final OutputStream poutStdin = new PipedOutputStream(pinStdin);
                final PipedInputStream pinStdout = new PipedInputStream(BUFFER_SIZE);
                final OutputStream poutStdout = new PipedOutputStream(pinStdout);
                final ByteArrayOutputStream stderr = new ConstraintByteArrayOutputStream(1024)) {
            try {
                List<InputStream> stdinList;
                if (initial == null) {
                    stdinList = new LinkedList<>();
                } else {
                    stdinList = new LinkedList<>(Arrays.asList(initial));
                }
                stdinList.add(pinStdin);

                sink.setControl(
                    () -> {
                        if (client != null) {
                            client.close();
                        }
                    });
                sink.setStreams(pinStdout, poutStdin);
                sink.start();

                try {
                    client.executeCommand(
                            command,
                            new SequenceInputStream(Collections.enumeration(stdinList)),
                            poutStdout,
                            stderr);
                } catch (Exception e) {
                    if (stderr.size() == 0) {
                        throw e;
                    }

                    log.error(
                            "Swallowing exception as preferring stderr",
                            ExceptionUtils.getRootCauseMessage(e));
                    log.debug("Exception", e);
                } finally {
                    if (stderr.size() > 0) {
                        throw new RuntimeException(
                                String.format(
                                        "Unexpected error during execution: %1$s",
                                        new String(stderr.toByteArray(), StandardCharsets.UTF_8)));
                    }
                }
            } catch (Exception e) {
                log.error(
                        "SSH error running command {}:'{}': {}",
                        client.getDisplayHost(),
                        command,
                        ExceptionUtils.getRootCauseMessage(e));
                log.debug("Exception", e);
                throw e;
            } finally {
                sink.stop();
                sink.setStreams(null, null);
            }
        }

        log.debug("execute leave");
    }

    /**
     * Send file. Send file using the embedded SSHClient.
     */
    public void sendFile(
            String file1,
            String file2) throws Exception {
        client.sendFile(file1, file2);
    }

    /**
     * Recieve file. Receive file using the embedded SSHClient.
     */
    public void receiveFile(
            String file1,
            String file2) throws Exception {
        client.receiveFile(file1, file2);
    }
}
